Training LoopΒΆ
Critic Update: The critic is updated several times (crit_cycles) for each update of the generator. This helps improve the stability of the GAN.
- The critic loss is the mean critic score on fake images minus the mean score on real images, plus the gradient penalty term.
- The critic is trained by minimizing this loss.
Generator Update: After the critic has been updated, the generator is updated. The generator loss is the negative mean of the critic's scores on the fake images.
- The generator is trained to fool the critic, i.e., to push the critic's scores on fake images toward the scores it gives real images.
Visualization and Checkpoints: Periodically, the training loop displays generated and real images, logs losses, and saves model checkpoints.
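The loop below calls a `gen_noise` helper that is not shown in this section. A minimal sketch of the usual definition (the `device` default here is an assumption; the notebook passes its own global `device`):

```python
import torch

def gen_noise(batch_size, z_dim, device='cpu'):
    # One standard-normal latent vector of length z_dim per batch element
    return torch.randn(batch_size, z_dim, device=device)
```

With `z_dim = 200` and a batch of 128, this yields a `128 x 200` tensor, matching the shape noted in the loop.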
# Training loop
for epoch in range(n_epochs):
    for real, _ in tqdm(dataloader):
        cur_bs = len(real)  # 128
        real = real.to(device)

        ## Critic
        mean_crit_loss = 0
        for _ in range(crit_cycles):
            # zero the gradients of the critic optimizer
            crit_opt.zero_grad()
            noise = gen_noise(cur_bs, z_dim)
            fake = gen(noise)
            # detach so the critic update does not touch the generator's parameters
            crit_fake_pred = crit(fake.detach())
            crit_real_pred = crit(real)
            # alpha: one interpolation coefficient per batch element
            alpha = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)  # 128 x 1 x 1 x 1
            # gradient penalty
            gp = get_gp(real, fake.detach(), crit, alpha)
            # Wasserstein critic loss
            crit_loss = crit_fake_pred.mean() - crit_real_pred.mean() + gp
            # .item() extracts the plain Python number from the tensor
            mean_crit_loss += crit_loss.item() / crit_cycles
            # backpropagation and optimizer step
            crit_loss.backward(retain_graph=True)
            crit_opt.step()
        # record the critic loss averaged over the critic cycles
        crit_losses += [mean_crit_loss]

        ## Generator
        # zero the gradients of the generator optimizer
        gen_opt.zero_grad()
        # fresh noise, 128 x 200
        noise = gen_noise(cur_bs, z_dim)
        # pass the noise through the generator
        fake = gen(noise)
        # score the fakes with the critic (no detach this time)
        crit_fake_pred = crit(fake)
        # generator loss: negative mean of the critic's scores
        gen_loss = -crit_fake_pred.mean()
        # backpropagation
        gen_loss.backward()
        # update the parameters of the generator
        gen_opt.step()
        gen_losses += [gen_loss.item()]

        ## Checkpoints
        if cur_step % save_step == 0 and cur_step > 0:
            print('Saving checkpoint:', cur_step, save_step)
            # better to save files under distinct names, e.g. the epoch number
            save_checkpoint('latest')

        ## Visualization and statistics
        if cur_step % show_step == 0 and cur_step > 0:
            show(fake, wandbactivation=1, name='fake')
            show(real, wandbactivation=1, name='real')
            gen_mean = sum(gen_losses[-show_step:]) / show_step
            crit_mean = sum(crit_losses[-show_step:]) / show_step
            print(f'Epoch: {epoch}, step: {cur_step}, Generator loss: {gen_mean}, Critic loss: {crit_mean}')
            plt.plot(range(len(gen_losses)),
                     torch.Tensor(gen_losses),
                     label='Generator loss')
            plt.plot(range(len(crit_losses)),
                     torch.Tensor(crit_losses),
                     label='Critic loss')
            plt.ylim(-200, 200)
            plt.legend()
            plt.show()

        cur_step += 1
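The loop relies on `get_gp`, which is defined elsewhere in the notebook. A sketch of the standard WGAN-GP gradient penalty it presumably implements (the weight `gamma=10` is the usual choice, assumed here, and the signature mirrors the call above):

```python
import torch

def get_gp(real, fake, crit, alpha, gamma=10):
    # Interpolate between real and fake images; alpha broadcasts over C, H, W
    mix = real * alpha + fake * (1 - alpha)
    # Critic scores on the interpolated batch
    scores = crit(mix)
    # Gradient of the scores with respect to the interpolated images
    grad = torch.autograd.grad(
        outputs=scores,
        inputs=mix,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    # Flatten per-sample gradients and penalize L2 norms that deviate from 1
    grad = grad.reshape(len(grad), -1)
    return gamma * ((grad.norm(2, dim=1) - 1) ** 2).mean()
```

`create_graph=True` is what lets `crit_loss.backward()` propagate through the penalty term into the critic's parameters.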
Saving checkpoint: 10045 35 Saved checkpoint
Epoch: 0, step: 10045, Generator loss: 6.669975280761719, Critic loss: -3.8914854526519775
Saving checkpoint: 10080 35 Saved checkpoint
Epoch: 0, step: 10080, Generator loss: 10.462503433227539, Critic loss: -4.360108852386475
Saving checkpoint: 10115 35 Saved checkpoint
Epoch: 0, step: 10115, Generator loss: 4.7965593338012695, Critic loss: -3.7804861068725586
Saving checkpoint: 10150 35 Saved checkpoint
Epoch: 0, step: 10150, Generator loss: 4.233477592468262, Critic loss: -5.056338787078857
Saving checkpoint: 10185 35 Saved checkpoint
Epoch: 0, step: 10185, Generator loss: 5.628297328948975, Critic loss: -4.169041156768799
Saving checkpoint: 10220 35 Saved checkpoint
Epoch: 0, step: 10220, Generator loss: 12.158222198486328, Critic loss: -4.853008270263672
10,000 images / batch size 128 = 78.125, i.e. 79 steps per epoch (the final, partial batch still counts as a step).
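The steps-per-epoch arithmetic above, spelled out:

```python
import math

n_images, batch_size = 10_000, 128
full_batches = n_images // batch_size            # 78 full batches of 128
steps_per_epoch = math.ceil(n_images / batch_size)  # 79, counting the final partial batch
print(full_batches, steps_per_epoch)  # 78 79
```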
Summary of the GAN Workflow:ΒΆ
- Noise Generation: Random noise is generated and passed into the generator.
- Image Generation: The generator transforms the noise into a synthetic image.
- Critic Evaluation: The critic scores both real and fake images, trying to distinguish between them.
- Loss Calculation:
  - The generator tries to minimize the critic's ability to detect fakes.
  - The critic tries to maximize the gap between its scores on real and fake images, while the gradient penalty keeps its gradients smooth.
- Training Loop: Alternates between training the critic and the generator, improving the quality of the generated images over time.

This code follows standard WGAN-GP practice, such as applying a gradient penalty and running several critic updates per generator update, which makes it suitable for generating high-quality images.